<html><!-- Created using the cpp_pretty_printer from the dlib C++ library.  See http://dlib.net for updates. --><head><title>dlib C++ Library - auto.cpp</title></head><body bgcolor='white'><pre>
<font color='#009900'>// Copyright (C) 2018  Davis E. King (davis@dlib.net)
</font><font color='#009900'>// License: Boost Software License   See LICENSE.txt for the full license.
</font><font color='#0000FF'>#ifndef</font> DLIB_AUTO_LEARnING_CPP_
<font color='#0000FF'>#define</font> DLIB_AUTO_LEARnING_CPP_

<font color='#0000FF'>#include</font> "<a style='text-decoration:none' href='auto.h.html'>auto.h</a>"
<font color='#0000FF'>#include</font> "<a style='text-decoration:none' href='../global_optimization.h.html'>../global_optimization.h</a>"
<font color='#0000FF'>#include</font> "<a style='text-decoration:none' href='svm_c_trainer.h.html'>svm_c_trainer.h</a>"

<font color='#0000FF'>#include</font> <font color='#5555FF'>&lt;</font>iostream<font color='#5555FF'>&gt;</font>
<font color='#0000FF'>#include</font> <font color='#5555FF'>&lt;</font>thread<font color='#5555FF'>&gt;</font>

<font color='#0000FF'>namespace</font> dlib
<b>{</b>

    <font color='#009900'>// Trains a binary classifier that uses an RBF kernel SVM, selecting the kernel gamma
</font>    <font color='#009900'>// and the two per-class C values (c1 and c2) automatically.  The parameters are picked
</font>    <font color='#009900'>// by find_max_global(), which is given a 6-fold cross-validation score to maximize.
</font>    <font color='#009900'>//
</font>    <font color='#009900'>//   x, y        - training samples and their labels.  Samples with a positive label are
</font>    <font color='#009900'>//                 counted as one class and samples with a negative label as the other.
</font>    <font color='#009900'>//                 Both are taken by value: they are passed to randomize_samples() and
</font>    <font color='#009900'>//                 x is then overwritten with its normalized version.
</font>    <font color='#009900'>//   max_runtime - forwarded to find_max_global(); presumably the time budget for the
</font>    <font color='#009900'>//                 parameter search.
</font>    <font color='#009900'>//   be_verbose  - if true, progress messages and results are written to std::cout.
</font>    <font color='#009900'>//
</font>    <font color='#009900'>// Preconditions, both checked with DLIB_CASSERT: at least 6 samples of each class, and
</font>    <font color='#009900'>// is_binary_classification_problem(x,y) == true.
</font>    <font color='#009900'>//
</font>    <font color='#009900'>// Returns the vector_normalizer learned from x paired with the decision function that
</font>    <font color='#009900'>// svm_c_trainer produces on the normalized samples using the best parameters found.
</font>    normalized_function<font color='#5555FF'>&lt;</font>decision_function<font color='#5555FF'>&lt;</font>radial_basis_kernel<font color='#5555FF'>&lt;</font>matrix<font color='#5555FF'>&lt;</font><font color='#0000FF'><u>double</u></font>,<font color='#979000'>0</font>,<font color='#979000'>1</font><font color='#5555FF'>&gt;</font><font color='#5555FF'>&gt;</font><font color='#5555FF'>&gt;</font><font color='#5555FF'>&gt;</font> <b><a name='auto_train_rbf_classifier'></a>auto_train_rbf_classifier</b> <font face='Lucida Console'>(</font>
        std::vector<font color='#5555FF'>&lt;</font>matrix<font color='#5555FF'>&lt;</font><font color='#0000FF'><u>double</u></font>,<font color='#979000'>0</font>,<font color='#979000'>1</font><font color='#5555FF'>&gt;</font><font color='#5555FF'>&gt;</font> x,
        std::vector<font color='#5555FF'>&lt;</font><font color='#0000FF'><u>double</u></font><font color='#5555FF'>&gt;</font> y,
        <font color='#0000FF'>const</font> std::chrono::nanoseconds max_runtime,
        <font color='#0000FF'><u>bool</u></font> be_verbose 
    <font face='Lucida Console'>)</font>
    <b>{</b>
        <font color='#009900'>// The cross validation below uses 6 folds, hence (presumably) the minimum of 6
</font>        <font color='#009900'>// samples per class.
</font>        <font color='#0000FF'>const</font> <font color='#0000FF'>auto</font> num_positive_training_samples <font color='#5555FF'>=</font> <font color='#BB00BB'>sum</font><font face='Lucida Console'>(</font><font color='#BB00BB'>mat</font><font face='Lucida Console'>(</font>y<font face='Lucida Console'>)</font><font color='#5555FF'>&gt;</font><font color='#979000'>0</font><font face='Lucida Console'>)</font>;
        <font color='#0000FF'>const</font> <font color='#0000FF'>auto</font> num_negative_training_samples <font color='#5555FF'>=</font> <font color='#BB00BB'>sum</font><font face='Lucida Console'>(</font><font color='#BB00BB'>mat</font><font face='Lucida Console'>(</font>y<font face='Lucida Console'>)</font><font color='#5555FF'>&lt;</font><font color='#979000'>0</font><font face='Lucida Console'>)</font>;
        <font color='#BB00BB'>DLIB_CASSERT</font><font face='Lucida Console'>(</font>num_positive_training_samples <font color='#5555FF'>&gt;</font><font color='#5555FF'>=</font> <font color='#979000'>6</font> <font color='#5555FF'>&amp;</font><font color='#5555FF'>&amp;</font> num_negative_training_samples <font color='#5555FF'>&gt;</font><font color='#5555FF'>=</font> <font color='#979000'>6</font>,
            "<font color='#CC0000'>You must provide at least 6 examples of each class to this training routine.</font>"<font face='Lucida Console'>)</font>;
        <font color='#009900'>// make sure requires clause is not broken
</font>        <font color='#BB00BB'>DLIB_CASSERT</font><font face='Lucida Console'>(</font><font color='#BB00BB'>is_binary_classification_problem</font><font face='Lucida Console'>(</font>x,y<font face='Lucida Console'>)</font> <font color='#5555FF'>=</font><font color='#5555FF'>=</font> <font color='#979000'>true</font>,
            "<font color='#CC0000'>\tnormalized_function auto_train_rbf_classifier(x,y,max_runtime,be_verbose)</font>"
            <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> "<font color='#CC0000'>\n\t invalid inputs were given to this function</font>"
            <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> "<font color='#CC0000'>\n\t x.size(): </font>" <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> x.<font color='#BB00BB'>size</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> 
            <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> "<font color='#CC0000'>\n\t y.size(): </font>" <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> y.<font color='#BB00BB'>size</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> 
            <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> "<font color='#CC0000'>\n\t is_binary_classification_problem(x,y): </font>" <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> <font color='#BB00BB'>is_binary_classification_problem</font><font face='Lucida Console'>(</font>x,y<font face='Lucida Console'>)</font>
        <font face='Lucida Console'>)</font>;


        <font color='#BB00BB'>randomize_samples</font><font face='Lucida Console'>(</font>x,y<font face='Lucida Console'>)</font>;

        vector_normalizer<font color='#5555FF'>&lt;</font>matrix<font color='#5555FF'>&lt;</font><font color='#0000FF'><u>double</u></font>,<font color='#979000'>0</font>,<font color='#979000'>1</font><font color='#5555FF'>&gt;</font><font color='#5555FF'>&gt;</font> normalizer;
        <font color='#009900'>// let the normalizer learn the mean and standard deviation of the samples
</font>        normalizer.<font color='#BB00BB'>train</font><font face='Lucida Console'>(</font>x<font face='Lucida Console'>)</font>;
        <font color='#009900'>// from here on x holds the normalized samples
</font>        <font color='#0000FF'>for</font> <font face='Lucida Console'>(</font><font color='#0000FF'>auto</font><font color='#5555FF'>&amp;</font> samp : x<font face='Lucida Console'>)</font>
            samp <font color='#5555FF'>=</font> <font color='#BB00BB'>normalizer</font><font face='Lucida Console'>(</font>samp<font face='Lucida Console'>)</font>;


        <font color='#009900'>// the normalizer is stored in the returned object alongside the trained function
</font>        normalized_function<font color='#5555FF'>&lt;</font>decision_function<font color='#5555FF'>&lt;</font>radial_basis_kernel<font color='#5555FF'>&lt;</font>matrix<font color='#5555FF'>&lt;</font><font color='#0000FF'><u>double</u></font>,<font color='#979000'>0</font>,<font color='#979000'>1</font><font color='#5555FF'>&gt;</font><font color='#5555FF'>&gt;</font><font color='#5555FF'>&gt;</font><font color='#5555FF'>&gt;</font> df;
        df.normalizer <font color='#5555FF'>=</font> normalizer;

        <font color='#0000FF'>typedef</font> radial_basis_kernel<font color='#5555FF'>&lt;</font>matrix<font color='#5555FF'>&lt;</font><font color='#0000FF'><u>double</u></font>,<font color='#979000'>0</font>,<font color='#979000'>1</font><font color='#5555FF'>&gt;</font><font color='#5555FF'>&gt;</font> kernel_type;

        <font color='#009900'>// Serializes the verbose output below.  The lambda is handed to find_max_global()
</font>        <font color='#009900'>// together with default_thread_pool(), so presumably it runs on several threads.
</font>        std::mutex m;
        <font color='#0000FF'>auto</font> cross_validation_score <font color='#5555FF'>=</font> [<font color='#5555FF'>&amp;</font>]<font face='Lucida Console'>(</font><font color='#0000FF'>const</font> <font color='#0000FF'><u>double</u></font> gamma, <font color='#0000FF'>const</font> <font color='#0000FF'><u>double</u></font> c1, <font color='#0000FF'>const</font> <font color='#0000FF'><u>double</u></font> c2<font face='Lucida Console'>)</font> 
        <b>{</b>
            svm_c_trainer<font color='#5555FF'>&lt;</font>kernel_type<font color='#5555FF'>&gt;</font> trainer;
            trainer.<font color='#BB00BB'>set_kernel</font><font face='Lucida Console'>(</font><font color='#BB00BB'>kernel_type</font><font face='Lucida Console'>(</font>gamma<font face='Lucida Console'>)</font><font face='Lucida Console'>)</font>;
            trainer.<font color='#BB00BB'>set_c_class1</font><font face='Lucida Console'>(</font>c1<font face='Lucida Console'>)</font>;
            trainer.<font color='#BB00BB'>set_c_class2</font><font face='Lucida Console'>(</font>c2<font face='Lucida Console'>)</font>;

            <font color='#009900'>// Perform 6-fold cross validation with these parameters and optionally print the results.
</font>            matrix<font color='#5555FF'>&lt;</font><font color='#0000FF'><u>double</u></font><font color='#5555FF'>&gt;</font> result <font color='#5555FF'>=</font> <font color='#BB00BB'>cross_validate_trainer</font><font face='Lucida Console'>(</font>trainer, x, y, <font color='#979000'>6</font><font face='Lucida Console'>)</font>;
            <font color='#0000FF'>if</font> <font face='Lucida Console'>(</font>be_verbose<font face='Lucida Console'>)</font>
            <b>{</b>
                std::lock_guard<font color='#5555FF'>&lt;</font>std::mutex<font color='#5555FF'>&gt;</font> <font color='#BB00BB'>lock</font><font face='Lucida Console'>(</font>m<font face='Lucida Console'>)</font>;
                std::cout <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> "<font color='#CC0000'>gamma: </font>" <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> std::<font color='#BB00BB'>setw</font><font face='Lucida Console'>(</font><font color='#979000'>11</font><font face='Lucida Console'>)</font> <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> gamma <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> "<font color='#CC0000'>  c1: </font>" <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> std::<font color='#BB00BB'>setw</font><font face='Lucida Console'>(</font><font color='#979000'>11</font><font face='Lucida Console'>)</font> <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> c1 <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font>  "<font color='#CC0000'>  c2: </font>" <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> std::<font color='#BB00BB'>setw</font><font face='Lucida Console'>(</font><font color='#979000'>11</font><font face='Lucida Console'>)</font> <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> c2 <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font>  "<font color='#CC0000'>  cross validation accuracy: </font>" <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> result <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> std::flush;
            <b>}</b>

            <font color='#009900'>// F1 style score: 2*product/sum of the values in result.  If the sum is zero the
</font>            <font color='#009900'>// ratio would be 0/0 (NaN), which must not reach the optimizer, so that case
</font>            <font color='#009900'>// scores 0 instead.
</font>            <font color='#0000FF'>const</font> <font color='#0000FF'><u>double</u></font> accuracy_sum <font color='#5555FF'>=</font> <font color='#BB00BB'>sum</font><font face='Lucida Console'>(</font>result<font face='Lucida Console'>)</font>;
            <font color='#0000FF'><u>double</u></font> f1 <font color='#5555FF'>=</font> <font color='#979000'>0</font>;
            <font color='#0000FF'>if</font> <font face='Lucida Console'>(</font>accuracy_sum <font color='#5555FF'>&gt;</font> <font color='#979000'>0</font><font face='Lucida Console'>)</font>
                f1 <font color='#5555FF'>=</font> <font color='#979000'>2</font><font color='#5555FF'>*</font><font color='#BB00BB'>prod</font><font face='Lucida Console'>(</font>result<font face='Lucida Console'>)</font><font color='#5555FF'>/</font>accuracy_sum;

            <font color='#009900'>// return the f1 score minus a small penalty for picking large parameter settings
</font>            <font color='#009900'>// since those are, a priori, less likely to generalize.
</font>            <font color='#0000FF'>return</font> f1 <font color='#5555FF'>-</font> std::<font color='#BB00BB'>max</font><font face='Lucida Console'>(</font>c1,c2<font face='Lucida Console'>)</font><font color='#5555FF'>/</font><font color='#979000'>1e12</font> <font color='#5555FF'>-</font> gamma<font color='#5555FF'>/</font><font color='#979000'>1e8</font>;
        <b>}</b>;


        <font color='#0000FF'>if</font> <font face='Lucida Console'>(</font>be_verbose<font face='Lucida Console'>)</font>
            std::cout <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> "<font color='#CC0000'>Searching for best RBF-SVM training parameters...</font>" <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> std::endl;
        <font color='#0000FF'>auto</font> result <font color='#5555FF'>=</font> <font color='#BB00BB'>find_max_global</font><font face='Lucida Console'>(</font>
            <font color='#BB00BB'>default_thread_pool</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>,
            cross_validation_score, 
            <b>{</b><font color='#979000'>1e</font><font color='#5555FF'>-</font><font color='#979000'>5</font>, <font color='#979000'>1e</font><font color='#5555FF'>-</font><font color='#979000'>5</font>, <font color='#979000'>1e</font><font color='#5555FF'>-</font><font color='#979000'>5</font><b>}</b>,  <font color='#009900'>// lower bound constraints on gamma, c1, and c2, respectively
</font>            <b>{</b><font color='#979000'>100</font>,  <font color='#979000'>1e6</font>,  <font color='#979000'>1e6</font><b>}</b>,   <font color='#009900'>// upper bound constraints on gamma, c1, and c2, respectively
</font>            max_runtime<font face='Lucida Console'>)</font>;

        <font color='#009900'>// result.x is indexed in the same order as the lambda arguments: gamma, c1, c2
</font>        <font color='#0000FF'><u>double</u></font> best_gamma <font color='#5555FF'>=</font> result.<font color='#BB00BB'>x</font><font face='Lucida Console'>(</font><font color='#979000'>0</font><font face='Lucida Console'>)</font>;
        <font color='#0000FF'><u>double</u></font> best_c1    <font color='#5555FF'>=</font> result.<font color='#BB00BB'>x</font><font face='Lucida Console'>(</font><font color='#979000'>1</font><font face='Lucida Console'>)</font>;
        <font color='#0000FF'><u>double</u></font> best_c2    <font color='#5555FF'>=</font> result.<font color='#BB00BB'>x</font><font face='Lucida Console'>(</font><font color='#979000'>2</font><font face='Lucida Console'>)</font>;

        <font color='#0000FF'>if</font> <font face='Lucida Console'>(</font>be_verbose<font face='Lucida Console'>)</font>
        <b>{</b>
            std::cout <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> "<font color='#CC0000'> best cross-validation score: </font>" <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> result.y <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> std::endl;
            std::cout <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> "<font color='#CC0000'> best gamma: </font>" <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> best_gamma <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> "<font color='#CC0000'>   best c1: </font>" <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> best_c1 <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> "<font color='#CC0000'>    best c2: </font>"<font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> best_c2  <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> std::endl;
        <b>}</b>

        <font color='#009900'>// retrain on all the (normalized) samples using the best parameters found
</font>        svm_c_trainer<font color='#5555FF'>&lt;</font>kernel_type<font color='#5555FF'>&gt;</font> trainer;
        trainer.<font color='#BB00BB'>set_kernel</font><font face='Lucida Console'>(</font><font color='#BB00BB'>kernel_type</font><font face='Lucida Console'>(</font>best_gamma<font face='Lucida Console'>)</font><font face='Lucida Console'>)</font>;
        trainer.<font color='#BB00BB'>set_c_class1</font><font face='Lucida Console'>(</font>best_c1<font face='Lucida Console'>)</font>;
        trainer.<font color='#BB00BB'>set_c_class2</font><font face='Lucida Console'>(</font>best_c2<font face='Lucida Console'>)</font>;

        <font color='#0000FF'>if</font> <font face='Lucida Console'>(</font>be_verbose<font face='Lucida Console'>)</font>
            std::cout <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> "<font color='#CC0000'>Training final classifier with best parameters...</font>" <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> std::endl;

        df.function <font color='#5555FF'>=</font> trainer.<font color='#BB00BB'>train</font><font face='Lucida Console'>(</font>x,y<font face='Lucida Console'>)</font>;

        <font color='#0000FF'>return</font> df;
    <b>}</b>
<b>}</b>

<font color='#0000FF'>#endif</font> <font color='#009900'>// DLIB_AUTO_LEARnING_CPP_
</font>


</pre></body></html>